/*
 * Copyright 2010-2021 JetBrains s.r.o. and Kotlin Programming Language contributors.
 * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
 */

package org.jetbrains.kotlin.fir.extensions

import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.resolve.providers.FirSymbolNamesProvider
import org.jetbrains.kotlin.fir.resolve.providers.FirSymbolProvider
import org.jetbrains.kotlin.fir.resolve.providers.FirSymbolProviderInternals
import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirClassLikeSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirNamedFunctionSymbol
import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name

/**
 * A [FirSymbolProvider] that forwards every request to a [FirExtensionDeclarationsSymbolProvider] but can be
 * temporarily switched off.
 *
 * The switch is needed because the delegate must not be consulted while the annotations used by compiler plugins
 * are being resolved: at that stage the predicate-based provider is not indexed yet, so it would return empty
 * results for all requests. While disabled, symbol lookups return nothing, [hasPackage] returns `false`, and the
 * names provider returns `null` for every query.
 *
 * Disabling it during that phase is safe because plugins cannot generate annotation classes, so generated
 * declarations cannot influence annotation resolution for other plugins or for the generating plugin itself.
 */
open class FirSwitchableExtensionDeclarationsSymbolProvider protected constructor(
    private val delegate: FirExtensionDeclarationsSymbolProvider
) : FirSymbolProvider(delegate.session) {
    companion object {
        /**
         * Wraps the provider returned by [FirExtensionDeclarationsSymbolProvider.createIfNeeded] for [session],
         * or returns `null` when that factory returns `null`.
         */
        fun createIfNeeded(session: FirSession): FirSwitchableExtensionDeclarationsSymbolProvider? =
            FirExtensionDeclarationsSymbolProvider.createIfNeeded(session)?.let { FirSwitchableExtensionDeclarationsSymbolProvider(it) }
    }

    /** `true` while lookups are suppressed; toggled by [disable] and [enable]. */
    protected open var disabled: Boolean = false

    /**
     * Forwards name queries to the delegate's names provider while enabled.
     *
     * NOTE(review): while disabled every query answers `null` rather than an empty set — presumably so the
     * temporary state is not mistaken for a precise answer; confirm against the [FirSymbolNamesProvider] contract.
     */
    override val symbolNamesProvider: FirSymbolNamesProvider = object : FirSymbolNamesProvider() {
        override fun getPackageNames(): Set<String>? =
            if (disabled) null else delegate.symbolNamesProvider.getPackageNames()

        override val hasSpecificClassifierPackageNamesComputation: Boolean
            get() = delegate.symbolNamesProvider.hasSpecificClassifierPackageNamesComputation

        override fun getPackageNamesWithTopLevelClassifiers(): Set<String>? =
            if (disabled) null else delegate.symbolNamesProvider.getPackageNamesWithTopLevelClassifiers()

        override val hasSpecificCallablePackageNamesComputation: Boolean
            get() = delegate.symbolNamesProvider.hasSpecificCallablePackageNamesComputation

        override fun getPackageNamesWithTopLevelCallables(): Set<String>? =
            if (disabled) null else delegate.symbolNamesProvider.getPackageNamesWithTopLevelCallables()

        override fun getTopLevelClassifierNamesInPackage(packageFqName: FqName): Set<Name>? =
            if (disabled) null else delegate.symbolNamesProvider.getTopLevelClassifierNamesInPackage(packageFqName)

        override fun getTopLevelCallableNamesInPackage(packageFqName: FqName): Set<Name>? =
            if (disabled) null else delegate.symbolNamesProvider.getTopLevelCallableNamesInPackage(packageFqName)
    }

    /** Returns the delegate's class-like symbol for [classId], or `null` while disabled. */
    override fun getClassLikeSymbolByClassId(classId: ClassId): FirClassLikeSymbol<*>? =
        if (disabled) null else delegate.getClassLikeSymbolByClassId(classId)

    /** Appends the delegate's top-level callables to [destination]; adds nothing while disabled. */
    @FirSymbolProviderInternals
    override fun getTopLevelCallableSymbolsTo(destination: MutableList<FirCallableSymbol<*>>, packageFqName: FqName, name: Name) {
        if (!disabled) delegate.getTopLevelCallableSymbolsTo(destination, packageFqName, name)
    }

    /** Appends the delegate's top-level functions to [destination]; adds nothing while disabled. */
    @FirSymbolProviderInternals
    override fun getTopLevelFunctionSymbolsTo(destination: MutableList<FirNamedFunctionSymbol>, packageFqName: FqName, name: Name) {
        if (!disabled) delegate.getTopLevelFunctionSymbolsTo(destination, packageFqName, name)
    }

    /** Appends the delegate's top-level properties to [destination]; adds nothing while disabled. */
    @FirSymbolProviderInternals
    override fun getTopLevelPropertySymbolsTo(destination: MutableList<FirPropertySymbol>, packageFqName: FqName, name: Name) {
        if (!disabled) delegate.getTopLevelPropertySymbolsTo(destination, packageFqName, name)
    }

    /** Delegates the package check while enabled; always `false` while disabled. */
    override fun hasPackage(fqName: FqName): Boolean =
        !disabled && delegate.hasPackage(fqName)

    /**
     * Switches this provider off until [enable] is called.
     *
     * @throws IllegalStateException if the provider is already disabled.
     */
    @FirSymbolProviderInternals
    fun disable() {
        // `check` rather than `require`: this validates the provider's own state, not an argument.
        check(!disabled) {
            "Attempt to disable already disabled ${FirSwitchableExtensionDeclarationsSymbolProvider::class}"
        }

        disabled = true
    }

    /**
     * Switches this provider back on after a previous [disable].
     *
     * @throws IllegalStateException if the provider is already enabled.
     */
    @FirSymbolProviderInternals
    fun enable() {
        // `check` rather than `require`: this validates the provider's own state, not an argument.
        check(disabled) {
            "Attempt to enable already enabled ${FirSwitchableExtensionDeclarationsSymbolProvider::class}"
        }

        disabled = false
    }

    /** Reports whether lookups are currently suppressed. */
    @FirSymbolProviderInternals
    internal fun isDisabled(): Boolean = disabled
}

/**
 * The session's [FirSwitchableExtensionDeclarationsSymbolProvider] component, read through a nullable session
 * component accessor.
 *
 * NOTE(review): `null` presumably means no such component was registered for the session (e.g. when
 * [FirSwitchableExtensionDeclarationsSymbolProvider.createIfNeeded] returned `null`); confirm at the registration site.
 */
val FirSession.generatedDeclarationsSymbolProvider: FirSwitchableExtensionDeclarationsSymbolProvider? by FirSession.nullableSessionComponentAccessor()

/**
 * Runs [action] with the session's [generatedDeclarationsSymbolProvider] switched off, then switches it back on.
 *
 * Nothing is toggled when the session has no such provider or when the provider is already disabled (for example
 * by an enclosing call). Nested calls are therefore allowed, and only the call that actually disabled the provider
 * re-enables it. The provider is re-enabled even if [action] throws.
 */
@FirSymbolProviderInternals
fun FirSession.withGeneratedDeclarationsSymbolProviderDisabled(action: () -> Unit) {
    val provider = generatedDeclarationsSymbolProvider

    // No provider, or someone further up the stack already owns the disabled state: just run the action.
    if (provider == null || provider.isDisabled()) {
        action()
        return
    }

    provider.disable()
    try {
        action()
    } finally {
        provider.enable()
    }
}
